Machine Learning - Assignment 3¶
Importing the libraries¶
%pip install ucimlrepo
%pip install pandas
%pip install numpy
%pip install ipywidgets
%pip install sweetviz
%pip install scikit-learn
%pip install matplotlib
%pip install seaborn
%pip install graphviz
from ucimlrepo import fetch_ucirepo
import numpy as np
import pandas as pd
import sweetviz as sv
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, OneHotEncoder, LabelBinarizer
# fetch dataset
secondary_mushroom = fetch_ucirepo(id=848)
# data (as pandas dataframes)
X = secondary_mushroom.data.features
y = secondary_mushroom.data.targets
# metadata
print(secondary_mushroom.metadata)
# variable information
print(secondary_mushroom.variables)
{'uci_id': 848, 'name': 'Secondary Mushroom', 'repository_url': 'https://archive.ics.uci.edu/dataset/848/secondary+mushroom+dataset', 'data_url': 'https://archive.ics.uci.edu/static/public/848/data.csv', 'abstract': 'Dataset of simulated mushrooms for binary classification into edible and poisonous.', 'area': 'Biology', 'tasks': ['Classification'], 'characteristics': ['Tabular'], 'num_instances': 61068, 'num_features': 20, 'feature_types': ['Real'], 'demographics': [], 'target_col': ['class'], 'index_col': None, 'has_missing_values': 'yes', 'missing_values_symbol': 'NaN', 'year_of_dataset_creation': 2021, 'last_updated': 'Wed Apr 10 2024', 'dataset_doi': '10.24432/C5FP5Q', 'creators': ['Dennis Wagner', 'D. Heider', 'Georges Hattab'], 'intro_paper': {'title': 'Mushroom data creation, curation, and simulation to support classification tasks', 'authors': 'Dennis Wagner, D. Heider, Georges Hattab', 'published_in': 'Scientific Reports', 'year': 2021, 'url': 'https://www.semanticscholar.org/paper/336be248b6f1c5d77c3c93e89f2e19e7344b0250', 'doi': None}, 'additional_info': {'summary': 'The given information is about the Secondary Mushroom Dataset, the Primary Mushroom Dataset used for the simulation and the respective metadata can be found in the zip.\n\nThis dataset includes 61069 hypothetical mushrooms with caps based on 173 species (353 mushrooms\nper species). Each mushroom is identified as definitely edible, definitely poisonous, or of\nunknown edibility and not recommended (the latter class was combined with the poisonous class).\n\nThe related Python project contains a Python module secondary_data_generation.py\nused to generate this data based on primary_data_edited.csv also found in the repository.\nBoth nominal and metrical variables are a result of randomization.\nThe simulated and ordered by species version is found in secondary_data_generated.csv.\nThe randomly shuffled version is found in secondary_data_shuffled.csv.', 'purpose': 'Inspired by the Mushroom Data Set of J. Schlimmer: url:https://archive.ics.uci.edu/ml/datasets/Mushroom.', 'funded_by': None, 'instances_represent': None, 'recommended_data_splits': None, 'sensitive_data': None, 'preprocessing_description': None, 'variable_info': 'One binary class divided in edible=e and poisonous=p (with the latter one also containing mushrooms of unknown edibility).\nTwenty remaining variables (n: nominal, m: metrical)\n1. cap-diameter (m): float number in cm\n2. cap-shape (n): bell=b, conical=c, convex=x, flat=f,\nsunken=s, spherical=p, others=o\n3. cap-surface (n): fibrous=i, grooves=g, scaly=y, smooth=s,\nshiny=h, leathery=l, silky=k, sticky=t,\nwrinkled=w, fleshy=e\n4. cap-color (n): brown=n, buff=b, gray=g, green=r, pink=p,\npurple=u, red=e, white=w, yellow=y, blue=l,\norange=o, black=k\n5. does-bruise-bleed (n): bruises-or-bleeding=t,no=f\n6. gill-attachment (n): adnate=a, adnexed=x, decurrent=d, free=e,\nsinuate=s, pores=p, none=f, unknown=?\n7. gill-spacing (n): close=c, distant=d, none=f\n8. gill-color (n): see cap-color + none=f\n9. stem-height (m): float number in cm\n10. stem-width (m): float number in mm\n11. stem-root (n): bulbous=b, swollen=s, club=c, cup=u, equal=e,\nrhizomorphs=z, rooted=r\n12. stem-surface (n): see cap-surface + none=f\n13. stem-color (n): see cap-color + none=f\n14. veil-type (n): partial=p, universal=u\n15. veil-color (n): see cap-color + none=f\n16. has-ring (n): ring=t, none=f\n17. 
ring-type (n): cobwebby=c, evanescent=e, flaring=r, grooved=g,\nlarge=l, pendant=p, sheathing=s, zone=z, scaly=y, movable=m, none=f, unknown=?\n18. spore-print-color (n): see cap color\n19. habitat (n): grasses=g, leaves=l, meadows=m, paths=p, heaths=h,\nurban=u, waste=w, woods=d\n20. season (n): spring=s, summer=u, autumn=a, winter=w', 'citation': None}}
| # | name | role | type | missing_values |
|---|---|---|---|---|
| 0 | class | Target | Categorical | no |
| 1 | cap-diameter | Feature | Continuous | no |
| 2 | cap-shape | Feature | Categorical | no |
| 3 | cap-surface | Feature | Categorical | yes |
| 4 | cap-color | Feature | Categorical | no |
| 5 | does-bruise-or-bleed | Feature | Categorical | no |
| 6 | gill-attachment | Feature | Categorical | yes |
| 7 | gill-spacing | Feature | Categorical | yes |
| 8 | gill-color | Feature | Categorical | no |
| 9 | stem-height | Feature | Continuous | no |
| 10 | stem-width | Feature | Continuous | no |
| 11 | stem-root | Feature | Categorical | yes |
| 12 | stem-surface | Feature | Categorical | yes |
| 13 | stem-color | Feature | Categorical | no |
| 14 | veil-type | Feature | Categorical | yes |
| 15 | veil-color | Feature | Categorical | yes |
| 16 | has-ring | Feature | Categorical | no |
| 17 | ring-type | Feature | Categorical | yes |
| 18 | spore-print-color | Feature | Categorical | yes |
| 19 | habitat | Feature | Categorical | no |
| 20 | season | Feature | Categorical | no |

(the demographic, description and units fields are None for every variable)
Dataset analysis¶
What is this dataset in the first place? What is its purpose? In which direction do we want to steer the classification?
#merge X and y
df = pd.concat([X, y], axis=1)
# data preview
df.head()
| | cap-diameter | cap-shape | cap-surface | cap-color | does-bruise-or-bleed | gill-attachment | gill-spacing | gill-color | stem-height | stem-width | ... | stem-surface | stem-color | veil-type | veil-color | has-ring | ring-type | spore-print-color | habitat | season | class |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 15.26 | x | g | o | f | e | NaN | w | 16.95 | 17.09 | ... | y | w | u | w | t | g | NaN | d | w | p |
| 1 | 16.60 | x | g | o | f | e | NaN | w | 17.99 | 18.19 | ... | y | w | u | w | t | g | NaN | d | u | p |
| 2 | 14.07 | x | g | o | f | e | NaN | w | 17.80 | 17.74 | ... | y | w | u | w | t | g | NaN | d | w | p |
| 3 | 14.17 | f | h | e | f | e | NaN | w | 15.77 | 15.98 | ... | y | w | u | w | t | p | NaN | d | w | p |
| 4 | 14.64 | x | h | o | f | e | NaN | w | 16.53 | 17.20 | ... | y | w | u | w | t | p | NaN | d | w | p |
5 rows × 21 columns
#show class distribution
df['class'].value_counts()
class
p    33888
e    27181
Name: count, dtype: int64
#describe the data
df.describe()
| | cap-diameter | stem-height | stem-width |
|---|---|---|---|
| count | 61069.000000 | 61069.000000 | 61069.000000 |
| mean | 6.733854 | 6.581538 | 12.149410 |
| std | 5.264845 | 3.370017 | 10.035955 |
| min | 0.380000 | 0.000000 | 0.000000 |
| 25% | 3.480000 | 4.640000 | 5.210000 |
| 50% | 5.860000 | 5.950000 | 10.190000 |
| 75% | 8.540000 | 7.740000 | 16.570000 |
| max | 62.340000 | 33.920000 | 103.910000 |
#show missing values
df.isnull().sum()
cap-diameter                0
cap-shape                   0
cap-surface             14120
cap-color                   0
does-bruise-or-bleed        0
gill-attachment          9884
gill-spacing            25063
gill-color                  0
stem-height                 0
stem-width                  0
stem-root               51538
stem-surface            38124
stem-color                  0
veil-type               57892
veil-color              53656
has-ring                    0
ring-type                2471
spore-print-color       54715
habitat                     0
season                      0
class                       0
dtype: int64
#show unique values in the data
df.nunique()
cap-diameter            2571
cap-shape                  7
cap-surface               11
cap-color                 12
does-bruise-or-bleed       2
gill-attachment            7
gill-spacing               3
gill-color                12
stem-height             2226
stem-width              4630
stem-root                  5
stem-surface               8
stem-color                13
veil-type                  1
veil-color                 6
has-ring                   2
ring-type                  8
spore-print-color          7
habitat                    8
season                     4
class                      2
dtype: int64
Which columns could be simple yes/no (binary) features?
#show unique values in the data of 'does_bruise_or_bleed' column
df['does-bruise-or-bleed'].unique()
array(['f', 't'], dtype=object)
df['has-ring'].unique()
array(['t', 'f'], dtype=object)
df['class'].unique()
array(['p', 'e'], dtype=object)
CONCLUSION 1¶
- some columns can be encoded as a single binary column
- the remaining categorical columns as one-hot encoding...? (a short sketch follows this list)
- if we only consider decision trees, we do not need to worry about feature scaling; for other models it would be worth normalising the data.
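A minimal sketch of the first two points (an illustration only; the actual preprocessing pipeline is built with scikit-learn further below, and the column choices here are just examples):
# two-valued 't'/'f' columns can be mapped straight to booleans instead of two one-hot columns each
binary_cols = ["does-bruise-or-bleed", "has-ring"]
encoded = df.copy()
for col in binary_cols:
    encoded[col] = encoded[col].map({"t": True, "f": False})
# the remaining categorical columns would go through one-hot encoding, e.g. with pandas
pd.get_dummies(encoded[["cap-shape", "habitat"]]).head()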
Using AutoEDA (SweetViz)¶
# use sweetviz to generate a report
df["class"] = df["class"].map({"p": True, "e": False})
# feat_cfg = sv.FeatureConfig(force_cat=['class'])
report = sv.analyze(df, target_feat='class')
report.show_html("sweetviz_report.html")
Report sweetviz_report.html was generated! NOTEBOOK/COLAB USERS: the web browser MAY not pop up, regardless, the report IS saved in your notebook/colab files.
Conclusions from the AutoEDA¶
Correlation matrix:
- stem width is positively correlated with cap diameter, while stem height itself is not correlated as strongly
- obvious relationships are visible, such as the correlation between the presence of a ring and its type
- and others, although note that for categorical features the association measure is not symmetric
- the distributions are rather skewed, so the median is a reasonable fill value for missing entries
- "Giant mushrooms are a rarity" :)
The classes are roughly balanced, 45% to 55%. Importantly, the class weights in the tree can be adjusted to bias it in whichever direction we choose...
Missing values of particular features are shown in red; we should consider whether they simply mean the feature is absent when the question concerns its type (a quick check follows below).
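A quick way to probe the last point (a sketch, assuming the df with the boolean class column from the SweetViz cell above): compare the share of poisonous mushrooms between rows where a feature is missing and rows where it is present.
# share of poisonous mushrooms (class == True) with and without a value in selected columns
for col in ["cap-surface", "gill-spacing", "stem-root", "spore-print-color"]:
    missing_rate = df.loc[df[col].isna(), "class"].mean()
    present_rate = df.loc[df[col].notna(), "class"].mean()
    print(f"{col:20s} missing: {missing_rate:.2f}  present: {present_rate:.2f}")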
Preparing the data for modelling¶
Stratified split into training and test sets¶
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0, stratify=y)
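A quick sanity check of the stratification (a sketch): the class proportions in the training and test sets should match those of the full dataset.
# class proportions before and after the stratified split
print(y.value_counts(normalize=True))
print(y_train.value_counts(normalize=True))
print(y_test.value_counts(normalize=True))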
Building the data-preprocessing pipeline¶
from sklearn.preprocessing import Binarizer, MinMaxScaler, OneHotEncoder
from sklearn.impute import SimpleImputer
from sklearn.compose import ColumnTransformer
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import StratifiedKFold
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import FunctionTransformer
from sklearn import set_config
set_config(transform_output="pandas")
to_onehot = [
"cap-surface",
"cap-color",
"gill-attachment",
"gill-spacing",
"gill-color",
"stem-root",
"stem-surface",
"stem-color",
"veil-color",
"ring-type",
"spore-print-color",
"habitat",
"season",
"does-bruise-or-bleed",
"has-ring"
]
# note: cap-shape and veil-type are not listed here, so the ColumnTransformer below
# drops them (together with the *_present indicator columns) via its default remainder='drop'
numeric = ['cap-diameter', 'stem-height', 'stem-width']
# append a 0/1 "<column>_present" indicator for every column; these indicator columns
# are only kept downstream if the ColumnTransformer selects them
def is_missing(df):
df = pd.concat([df, df.notnull().astype(int).add_suffix("_present")], axis=1)
return df
onehot = OneHotEncoder(sparse_output=False, handle_unknown='ignore', dtype=bool)
minmax = MinMaxScaler()
cat_imputer = SimpleImputer(strategy='most_frequent')
num_imputer = SimpleImputer(strategy='median')
preprocessor_steps = []
preprocessor_steps.append(('missing_indicator', FunctionTransformer(func=is_missing)))
cat_transformer = Pipeline(steps=[
('imputer', cat_imputer),
('onehot', onehot)
])
numeric_transformer = Pipeline(steps=[
('imputer', num_imputer),
# ('minmax', minmax)
])
column_transformer = ColumnTransformer(
[
('categorical', cat_transformer, to_onehot),
('numeric', numeric_transformer, numeric)
],
)
preprocessor_steps.append(('column_transformer', column_transformer))
preprocessor = Pipeline(steps=preprocessor_steps)
preprocessor
Pipeline(steps=[('missing_indicator',
                 FunctionTransformer(func=<function is_missing at 0x0000022255948180>)),
                ('column_transformer',
                 ColumnTransformer(transformers=[('categorical',
                                                  Pipeline(steps=[('imputer',
                                                                   SimpleImputer(strategy='most_frequent')),
                                                                  ('onehot',
                                                                   OneHotEncoder(dtype=<class 'bool'>,
                                                                                 handle_unknown='ignore',
                                                                                 sparse_output=False))]),
                                                  ['cap-surface', 'cap-color',
                                                   'gill-attachment', 'gill-spacing',
                                                   'gill-color', 'stem-root',
                                                   'stem-surface', 'stem-color',
                                                   'veil-color', 'ring-type',
                                                   'spore-print-color', 'habitat',
                                                   'season', 'does-bruise-or-bleed',
                                                   'has-ring']),
                                                 ('numeric',
                                                  Pipeline(steps=[('imputer',
                                                                   SimpleImputer(strategy='median'))]),
                                                  ['cap-diameter', 'stem-height',
                                                   'stem-width'])]))])
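As a sanity check (a sketch; the searches below fit the pipeline anyway, and the variable name is just illustrative), the preprocessor can be fitted on the training data and the shape and names of the resulting columns inspected:
X_train_transformed = preprocessor.fit_transform(X_train)
print(X_train_transformed.shape)
print(list(X_train_transformed.columns)[:10])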
Training the models with hyperparameter search.¶
A helper method for analysing a chosen decision tree
from sklearn.metrics import classification_report
from IPython.display import SVG
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn import tree
import graphviz
import matplotlib.pyplot as plt
import seaborn as sns
def decision_tree_analysis(name, model, X_test, y_test):
features = model.named_steps["dt"].feature_names_in_
# Make predictions on the test set
y_pred = model.predict(X_test)
# Print the classification report
cr = classification_report(y_test, y_pred, output_dict=True)
# Plot the confusion matrix
cm = confusion_matrix(y_test, y_pred)
# disp.plot()
# Get the feature importances
importances = model.named_steps["dt"].feature_importances_
# Get the feature names
feature_names = features
# Create a DataFrame
feature_importances = pd.DataFrame(
{"feature": feature_names, "importance": importances}
)
# Sort values
feature_importances = feature_importances.sort_values("importance", ascending=True)
# make a plot instance but not show it
# fig, ax = plt.subplots()
# feature_importances.plot.barh(x="feature", y="importance", ax=ax)
# plt.title("Feature Importances")
# plt.xlabel("Importance")
# plt.ylabel("Feature")
# plt.tight_layout()
# plt.savefig(name + "_feature_importances.png")
# export the tree to SVG
dot_data = tree.export_graphviz(
model["dt"],
out_file=None,
feature_names=features,
class_names=["edible", "poisonous"],
filled=True,
rounded=True,
special_characters=True,
)
graph = graphviz.Source(dot_data)
graph.render(name, format="svg")
# get tree instance and print stats
tree_instance = model["dt"].tree_
# print(f"Number of nodes: {tree_instance.node_count}")
# print(f"Depth of the tree: {tree_instance.max_depth}")
# print(f"Average depth: {tree_instance.node_count / tree_instance.max_depth}")
return cr, cm, feature_importances, graph, tree_instance
def show_classification_report(cr, name):
fig, ax = plt.subplots(figsize=(5, 5))
sns.heatmap(
pd.DataFrame(cr).iloc[:-1, :].T,
annot=True,
cmap="viridis",
fmt=".4f",
linewidths=0.5,
linecolor="black",
ax=ax,
vmax=1,
vmin=0,
)
ax.set_title(name + " Classification Report")
plt.show()
return ax
def show_confusion_matrics(cm, name, model):
fig, ax = plt.subplots(figsize=(5, 5))
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=model.classes_)
disp.plot(ax=ax)
plt.title(name + " Confusion Matrix")
def show_feature_importances(fi, name):
fig, ax = plt.subplots(figsize=(10, 20))
fi.plot.barh(x="feature", y="importance", ax=ax)
plt.title(name + " Feature Importances")
plt.xlabel("Importance")
plt.ylabel("Feature")
plt.tight_layout()
plt.show()
def show_tree_stats(tree, name, verbose=True):
if verbose:
print(f"Number of nodes: {tree.node_count}")
print(f"Depth of the tree: {tree.max_depth}")
print(f"Average depth: {tree.node_count / tree.max_depth}")
#show barplot of stats
stats = pd.DataFrame({"Number of nodes": [tree.node_count], "Depth of the tree": [tree.max_depth], "Average depth": [tree.node_count / tree.max_depth], "Name" : name})
if verbose:
stats.plot.barh()
plt.title(name + " Tree Stats")
plt.show()
return stats
Training decision tree models [2, 3] using hyperparameter search (e.g. GridSearch [5]) and an appropriately chosen classification metric. The 4 suggested hyperparameters are: criterion, max_depth, min_samples_leaf, ccp_alpha. Pay particular attention to pruning (ccp_alpha). Analysis of the influence of the hyperparameters on the quality of the results.
Visualisation of the tree and analysis of the trees for different hyperparameters. ("How do the resulting trees differ under different sets of hyperparameters?")
from sklearn.experimental import enable_halving_search_cv
from sklearn.model_selection import HalvingGridSearchCV, GridSearchCV
param_grid = {
"dt__criterion": ["gini", "entropy"],
"dt__max_depth": [None, 5, 10, 15, 30],
"dt__min_samples_leaf": [1, 100, 1000],
"dt__ccp_alpha": [0.0, 0.001, 0.005],
}
dt_classifier = DecisionTreeClassifier()
model = Pipeline(
[
("preprocessor", preprocessor),
("dt", dt_classifier),
]
)
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
# Create the decision tree classifier
# Create the GridSearchCV object
grid_search = HalvingGridSearchCV(model, param_grid, cv=skf, n_jobs=-1, verbose=2, scoring='f1_macro')
# Fit the grid search to the data
grid_search.fit(X_train, y_train)
# Get the best hyperparameters
best_params = grid_search.best_params_
print(best_params)
c:\Users\filip\AppData\Local\Programs\Python\Python312\Lib\site-packages\sklearn\utils\validation.py:1300: DataConversionWarning: A column-vector y was passed when a 1d array was expected. Please change the shape of y to (n_samples, ), for example using ravel(). y = column_or_1d(y, warn=True)
n_iterations: 5
n_required_iterations: 5
n_possible_iterations: 5
min_resources_: 603
max_resources_: 48855
aggressive_elimination: False
factor: 3
----------
iter: 0
n_candidates: 90
n_resources: 603
Fitting 5 folds for each of 90 candidates, totalling 450 fits
----------
iter: 1
n_candidates: 30
n_resources: 1809
Fitting 5 folds for each of 30 candidates, totalling 150 fits
----------
iter: 2
n_candidates: 10
n_resources: 5427
Fitting 5 folds for each of 10 candidates, totalling 50 fits
----------
iter: 3
n_candidates: 4
n_resources: 16281
Fitting 5 folds for each of 4 candidates, totalling 20 fits
----------
iter: 4
n_candidates: 2
n_resources: 48843
Fitting 5 folds for each of 2 candidates, totalling 10 fits
{'dt__ccp_alpha': 0.0, 'dt__criterion': 'entropy', 'dt__max_depth': None, 'dt__min_samples_leaf': 1}
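For reference, the exhaustive GridSearchCV variant mentioned in the task could be run over the same grid (a sketch only, assuming the model, param_grid, skf, X_train and y_train objects defined above; it is considerably slower, because every candidate is fitted on the full training set):
from sklearn.model_selection import GridSearchCV
exhaustive_search = GridSearchCV(model, param_grid, cv=skf, n_jobs=-1, scoring='f1_macro')
exhaustive_search.fit(X_train, y_train.values.ravel())
print(exhaustive_search.best_params_)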
Analysis of the tree built with the best hyperparameters¶
best_model = grid_search.best_estimator_
best_cr, best_cm, best_fi, best_graph, best_tree = decision_tree_analysis("halving_search", best_model, X_test, y_test)
show_confusion_matrics(best_cm, "halving_search", best_model.named_steps["dt"])
show_classification_report(best_cr, "halving_search")
show_feature_importances(best_fi, "halving_search")
show_tree_stats(best_tree, "halving_search")
SVG(best_graph.pipe(format="svg"))
Number of nodes: 435 Depth of the tree: 28 Average depth: 15.535714285714286
Analysis of the influence of the maximum tree depth on the classification results.¶
# make a tree with 5 max depth
dt_classifier = DecisionTreeClassifier(max_depth=5)
model = Pipeline(
[
("preprocessor", preprocessor),
("dt", dt_classifier),
]
)
# normal fit
model.fit(X_train, y_train)
max_depth_5_cr, max_depth_5_cm, max_depth_5_fi, max_depth_5_graph, max_depth_5_tree = decision_tree_analysis("max_depth_5", model, X_test, y_test)
show_confusion_matrics(max_depth_5_cm, "max_depth_5", model.named_steps["dt"])
show_classification_report(max_depth_5_cr, "max_depth_5")
show_feature_importances(max_depth_5_fi, "max_depth_5")
show_tree_stats(max_depth_5_tree, "max_depth_5")
SVG(max_depth_5_graph.pipe(format="svg"))
Number of nodes: 45 Depth of the tree: 5 Average depth: 9.0
Influence of a different split criterion on the classification results.¶
dt_classifier = DecisionTreeClassifier(max_depth=5, criterion='entropy')
model = Pipeline(
[
("preprocessor", preprocessor),
("dt", dt_classifier),
]
)
# normal fit
model.fit(X_train, y_train)
max_depth_5_entropy_cr, max_depth_5_entropy_cm, max_depth_5_entropy_fi, max_depth_5_entropy_graph, max_depth_5_entropy_tree = decision_tree_analysis("max_depth_5_entropy", model, X_test, y_test)
show_confusion_matrics(max_depth_5_entropy_cm, "max_depth_5_entropy", model.named_steps["dt"])
show_classification_report(max_depth_5_entropy_cr, "max_depth_5_entropy")
show_feature_importances(max_depth_5_entropy_fi, "max_depth_5_entropy")
show_tree_stats(max_depth_5_entropy_tree, "max_depth_5_entropy")
SVG(max_depth_5_entropy_graph.pipe(format="svg"))
Number of nodes: 39 Depth of the tree: 5 Average depth: 7.8
# 5 max depth and high pruning
dt_classifier = DecisionTreeClassifier(max_depth=5, ccp_alpha=0.01)
model = Pipeline(
[
("preprocessor", preprocessor),
("dt", dt_classifier),
]
)
# normal fit
model.fit(X_train, y_train)
# show the tree
max_depth_5_pruned_cr, max_depth_5_pruned_cm, max_depth_5_pruned_fi, max_depth_5_pruned_graph, max_depth_5_pruned_tree = decision_tree_analysis("max_depth_5_pruned", model, X_test, y_test)
show_confusion_matrics(max_depth_5_pruned_cm, "max_depth_5_pruned", model.named_steps["dt"])
show_classification_report(max_depth_5_pruned_cr, "max_depth_5_pruned")
show_feature_importances(max_depth_5_pruned_fi, "max_depth_5_pruned")
show_tree_stats(max_depth_5_pruned_tree, "max_depth_5_pruned")
SVG(max_depth_5_pruned_graph.pipe(format="svg"))
Number of nodes: 19 Depth of the tree: 5 Average depth: 3.8
Examining how the min_samples_leaf parameter affects the classification results.¶
dt_classifier = DecisionTreeClassifier(max_depth=5, min_samples_leaf=1000)
model = Pipeline(
[
("preprocessor", preprocessor),
("dt", dt_classifier),
]
)
# normal fit
model.fit(X_train, y_train)
# show the tree
max_depth_5_min_samples_leaf_cr, max_depth_5_min_samples_leaf_cm, max_depth_5_min_samples_leaf_fi, max_depth_5_min_samples_leaf_graph, max_depth_5_min_samples_leaf_tree = decision_tree_analysis("max_depth_5_min_samples_leaf", model, X_test, y_test)
show_confusion_matrics(max_depth_5_min_samples_leaf_cm, "max_depth_5_min_samples_leaf", model.named_steps["dt"])
show_classification_report(max_depth_5_min_samples_leaf_cr, "max_depth_5_min_samples_leaf")
show_feature_importances(max_depth_5_min_samples_leaf_fi, "max_depth_5_min_samples_leaf")
show_tree_stats(max_depth_5_min_samples_leaf_tree, "max_depth_5_min_samples_leaf")
SVG(max_depth_5_min_samples_leaf_graph.pipe(format="svg"))
Number of nodes: 37 Depth of the tree: 5 Average depth: 7.4
Minimum number of samples per leaf¶
dt_classifier = DecisionTreeClassifier(min_samples_leaf=1000)
model = Pipeline(
[
("preprocessor", preprocessor),
("dt", dt_classifier),
]
)
model.fit(X_train, y_train)
min_samples_leaf_cr, min_samples_leaf_cm, min_samples_leaf_fi, min_samples_leaf_graph, min_samples_leaf_tree = decision_tree_analysis("min_samples_leaf", model, X_test, y_test)
show_confusion_matrics(min_samples_leaf_cm, "min_samples_leaf", model.named_steps["dt"])
show_classification_report(min_samples_leaf_cr, "min_samples_leaf")
show_feature_importances(min_samples_leaf_fi, "min_samples_leaf")
show_tree_stats(min_samples_leaf_tree, "min_samples_leaf")
SVG(min_samples_leaf_graph.pipe(format="svg"))
Number of nodes: 71 Depth of the tree: 10 Average depth: 7.1
Minimum impurity decrease¶
dt_classifier = DecisionTreeClassifier(min_impurity_decrease=0.001)
model = Pipeline(
[
("preprocessor", preprocessor),
("dt", dt_classifier),
]
)
model.fit(X_train, y_train)
min_impurity_decrease_cr, min_impurity_decrease_cm, min_impurity_decrease_fi, min_impurity_decrease_graph, min_impurity_decrease_tree = decision_tree_analysis("min_impurity_decrease", model, X_test, y_test)
show_confusion_matrics(min_impurity_decrease_cm, "min_impurity_decrease", model.named_steps["dt"])
show_classification_report(min_impurity_decrease_cr, "min_impurity_decrease")
show_feature_importances(min_impurity_decrease_fi, "min_impurity_decrease")
show_tree_stats(min_impurity_decrease_tree, "min_impurity_decrease")
SVG(min_impurity_decrease_graph.pipe(format="svg"))
Number of nodes: 169 Depth of the tree: 16 Average depth: 10.5625
Examining how class weights affect the model's results.¶
class_weight = {'e': 0.29, 'p': 0.71} # or maybe we should 'skew' towards p to minimise the risk of poisoning?
dt_classifier = DecisionTreeClassifier(max_depth=5, class_weight=class_weight)
model = Pipeline(
[
("preprocessor", preprocessor),
("dt", dt_classifier),
]
)
# normal fit
model.fit(X_train, y_train)
max_5_class_weight_cr, max_5_class_weight_cm, max_5_class_weight_fi, max_5_class_weight_graph, max_5_class_weight_tree = decision_tree_analysis("max_5_class_weight", model, X_test, y_test)
show_confusion_matrics(max_5_class_weight_cm, "max_5_class_weight", model.named_steps["dt"])
show_classification_report(max_5_class_weight_cr, "max_5_class_weight")
show_feature_importances(max_5_class_weight_fi, "max_5_class_weight")
show_tree_stats(max_5_class_weight_tree, "max_5_class_weight")
SVG(max_5_class_weight_graph.pipe(format="svg"))
Number of nodes: 57 Depth of the tree: 5 Average depth: 11.4
Examining the influence of ccp_alpha on the model's results.¶
import matplotlib.pyplot as plt
dt_classifier = DecisionTreeClassifier()
model = Pipeline(
[
("preprocessor", preprocessor),
("dt", dt_classifier),
]
)
# the shared `preprocessor` object was already fitted by the earlier pipeline fits,
# so it can transform the data directly here
X_train_preprocessed = model['preprocessor'].transform(X_train)
X_test_preprocessed = model['preprocessor'].transform(X_test)
path = model["dt"].cost_complexity_pruning_path(X_train_preprocessed, y_train)
ccp_alphas, impurities = path.ccp_alphas, path.impurities
fig, ax = plt.subplots()
ax.plot(ccp_alphas[:-1], impurities[:-1], marker='o', drawstyle="steps-post")
ax.set_xlabel("effective alpha")
ax.set_ylabel("total impurity of leaves")
ax.set_title("Total Impurity vs effective alpha for training set")
clfs = []
for ccp_alpha in ccp_alphas:
clf = DecisionTreeClassifier(random_state=0, ccp_alpha=ccp_alpha)
clf.fit(X_train_preprocessed, y_train)
clfs.append(clf)
clfs = clfs[:-1]
ccp_alphas = ccp_alphas[:-1]
node_counts = [clf.tree_.node_count for clf in clfs]
depth = [clf.tree_.max_depth for clf in clfs]
fig, ax = plt.subplots(2, 1)
ax[0].plot(ccp_alphas, node_counts, marker='o', drawstyle="steps-post")
ax[0].set_xlabel("alpha")
ax[0].set_ylabel("number of nodes")
ax[0].set_title("Number of nodes vs alpha")
ax[1].plot(ccp_alphas, depth, marker='o', drawstyle="steps-post")
ax[1].set_xlabel("alpha")
ax[1].set_ylabel("depth of tree")
ax[1].set_title("Depth vs alpha")
fig.tight_layout()
Summary¶
Confusion matrices¶
fig, ax = plt.subplots(2, 4, figsize=(20, 10))
labels = best_model.named_steps["dt"].classes_
ConfusionMatrixDisplay(confusion_matrix=max_depth_5_cm, display_labels=labels).plot(
ax=ax[0, 0]
)
ax[0, 0].set_title("max_depth_5")
ConfusionMatrixDisplay(confusion_matrix=max_depth_5_entropy_cm, display_labels=labels).plot(
ax=ax[0, 1]
)
ax[0, 1].set_title("max_depth_5_entropy")
ConfusionMatrixDisplay(
confusion_matrix=max_depth_5_pruned_cm, display_labels=labels
).plot(ax=ax[0, 2])
ax[0, 2].set_title("max_depth_5_pruned")
ConfusionMatrixDisplay(
confusion_matrix=max_depth_5_min_samples_leaf_cm, display_labels=labels
).plot(ax=ax[0, 3])
ax[0, 3].set_title("max_depth_5_min_samples_leaf")
ConfusionMatrixDisplay(
confusion_matrix=min_samples_leaf_cm, display_labels=labels
).plot(ax=ax[1, 0])
ax[1, 0].set_title("min_samples_leaf")
ConfusionMatrixDisplay(
confusion_matrix=min_impurity_decrease_cm, display_labels=labels
).plot(ax=ax[1, 1])
ax[1, 1].set_title("min_impurity_decrease")
ConfusionMatrixDisplay(
confusion_matrix=max_5_class_weight_cm, display_labels=labels
).plot(ax=ax[1, 2])
ax[1, 2].set_title("max_5_class_weight")
ConfusionMatrixDisplay(confusion_matrix=best_cm, display_labels=labels).plot(
ax=ax[1, 3]
)
ax[1, 3].set_title("best")
fig.tight_layout()
plt.show()
Classification reports¶
fig, ax = plt.subplots(2, 4, figsize=(20, 10))
sns.heatmap(
pd.DataFrame(max_depth_5_cr).iloc[:-1, :].T,
annot=True,
ax=ax[0][0],
cmap="viridis",
vmin=0,
vmax=1,
fmt=".4f",
)
ax[0][0].set_title("max_depth_5")
sns.heatmap(
pd.DataFrame(max_depth_5_entropy_cr).iloc[:-1, :].T,
annot=True,
ax=ax[0][1],
cmap="viridis",
vmin=0,
vmax=1,
fmt=".4f",
)
ax[0][1].set_title("max_depth_5_entropy")
sns.heatmap(
pd.DataFrame(max_depth_5_pruned_cr).iloc[:-1, :].T,
annot=True,
ax=ax[0][2],
cmap="viridis",
vmin=0,
vmax=1,
fmt=".4f",
)
ax[0][2].set_title("max_depth_5_pruned")
sns.heatmap(
pd.DataFrame(max_depth_5_min_samples_leaf_cr).iloc[:-1, :].T,
annot=True,
ax=ax[0][3],
cmap="viridis",
vmin=0,
vmax=1,
fmt=".4f",
)
ax[0][3].set_title("max_depth_5_min_samples_leaf")
sns.heatmap(
pd.DataFrame(min_samples_leaf_cr).iloc[:-1, :].T,
annot=True,
ax=ax[1][0],
cmap="viridis",
vmin=0,
vmax=1,
fmt=".4f",
)
ax[1][0].set_title("min_samples_leaf")
sns.heatmap(
pd.DataFrame(min_impurity_decrease_cr).iloc[:-1, :].T,
annot=True,
ax=ax[1][1],
cmap="viridis",
vmin=0,
vmax=1,
fmt=".4f",
)
ax[1][1].set_title("min_impurity_decrease")
sns.heatmap(
pd.DataFrame(max_5_class_weight_cr).iloc[:-1, :].T,
annot=True,
ax=ax[1][2],
cmap="viridis",
vmin=0,
vmax=1,
fmt=".4f",
)
ax[1][2].set_title("max_5_class_weight")
sns.heatmap(
pd.DataFrame(best_cr).iloc[:-1, :].T,
annot=True,
ax=ax[1][3],
cmap="viridis",
vmin=0,
vmax=1,
fmt=".4f",
)
ax[1][3].set_title("best")
fig.tight_layout()
plt.show()
Feature importance report¶
fig, ax = plt.subplots(2, 4, figsize=(20, 10))
max_depth_5_fi.plot.barh(x="feature", y="importance", ax=ax[0][0])
ax[0][0].set_title("max_depth_5")
ax[0][0].set_yticklabels([])
max_depth_5_entropy_fi.plot.barh(x="feature", y="importance", ax=ax[0][1])
ax[0][1].set_title("max_depth_5_entropy")
ax[0][1].set_yticklabels([])
max_depth_5_pruned_fi.plot.barh(x="feature", y="importance", ax=ax[0][2])
ax[0][2].set_title("max_depth_5_pruned")
ax[0][2].set_yticklabels([])
max_depth_5_min_samples_leaf_fi.plot.barh(x="feature", y="importance", ax=ax[0][3])
ax[0][3].set_title("max_depth_5_min_samples_leaf")
ax[0][3].set_yticklabels([])
min_samples_leaf_fi.plot.barh(x="feature", y="importance", ax=ax[1][0])
ax[1][0].set_title("min_samples_leaf")
ax[1][0].set_yticklabels([])
min_impurity_decrease_fi.plot.barh(x="feature", y="importance", ax=ax[1][1])
ax[1][1].set_title("min_impurity_decrease")
ax[1][1].set_yticklabels([])
max_5_class_weight_fi.plot.barh(x="feature", y="importance", ax=ax[1][2])
ax[1][2].set_title("max_5_class_weight")
ax[1][2].set_yticklabels([])
best_fi.plot.barh(x="feature", y="importance", ax=ax[1][3])
ax[1][3].set_title("best")
ax[1][3].set_yticklabels([])
fig.tight_layout()
plt.show()
max_depth_5_tree_stats = show_tree_stats(max_depth_5_tree, "max_depth_5", False)
max_depth_5_entropy_tree_stats = show_tree_stats(
max_depth_5_entropy_tree, "max_depth_5_entropy", False
)
max_depth_5_pruned_tree_stats = show_tree_stats(
max_depth_5_pruned_tree, "max_depth_5_pruned", False
)
max_depth_5_min_samples_leaf_tree_stats = show_tree_stats(
max_depth_5_min_samples_leaf_tree, "max_depth_5_min_samples_leaf", False
)
min_samples_leaf_tree_stats = show_tree_stats(
min_samples_leaf_tree, "min_samples_leaf", False
)
min_impurity_decrease_tree_stats = show_tree_stats(
min_impurity_decrease_tree, "min_impurity_decrease", False
)
max_5_class_weight_tree_stats = show_tree_stats(
max_5_class_weight_tree, "max_5_class_weight", False
)
best_tree_stats = show_tree_stats(best_tree, "best", False)
# all of the above *_stats variables are DataFrames; concatenate them
concated = pd.concat([max_depth_5_tree_stats, max_depth_5_entropy_tree_stats, max_depth_5_pruned_tree_stats, max_depth_5_min_samples_leaf_tree_stats, min_samples_leaf_tree_stats, min_impurity_decrease_tree_stats, max_5_class_weight_tree_stats, best_tree_stats], axis=0)
# compare all stats by number of nodes
# compare 'Number of nodes' for all models
concated.plot.barh(y='Number of nodes', x='Name')
plt.title('Number of nodes for all models')
plt.show()
# compare 'Depth of the tree' for all models
concated.plot.barh(y='Depth of the tree', x='Name')
plt.title('Depth of the tree for all models')
plt.show()
# compare 'Average depth' for all models
concated.plot.barh(y='Average depth', x='Name')
plt.title('Average depth for all models')
plt.show()
FAQ¶
- What is stored in the leaves of the tree?
- The leaves store the classes that the model assigns to observations (more precisely, the class distribution of the training samples that reach the leaf; the majority class is used as the prediction).
- Is pruning the tree necessary? What does this process consist of?
- Pruning is a process in which branches that do not improve the model's results are removed. It reduces the complexity of the model, which can improve its generalisation.
- Can a tree be too "large" or too "small"?
- A tree can be too large if it has too many branches and leaves, which can lead to overfitting. Conversely, a tree can be too small if it has too few branches and leaves, which can lead to underfitting.
- Does a decision tree need normalisation/standardisation/discretisation of the data?
- A decision tree does not require normalisation/standardisation/discretisation, because its splits are insensitive to the scale of the features.
- Can the model be overfitted?
- The model can be overfitted if it is too complex, which leads to fitting the training data too closely and generalising poorly to new data.
- What does class weighting consist of?
- Class weighting assigns different weights to the classes in order to balance the influence of the individual classes on the model's results.
- By default DecisionTreeClassifier uses equal class weights (class_weight=None); with class_weight='balanced' it computes the weights from the class frequencies in the training set.
- What does cross validation consist of in the hyperparameter search algorithm?
- Cross-validation splits the dataset into k folds, of which one fold is used as the validation set and the remaining folds as the training set. The process is repeated k times and the results are averaged. This evaluates the model on different subsets of the data and makes the results more reliable (a short sketch follows below this list).
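A short sketch of the cross-validation mechanism described above, run outside the hyperparameter search (assuming the model pipeline, skf, X_train and y_train defined earlier):
from sklearn.model_selection import cross_val_score
scores = cross_val_score(model, X_train, y_train.values.ravel(), cv=skf, scoring='f1_macro', n_jobs=-1)
print(scores, scores.mean())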